# -*- coding: utf-8 -*-
import pandas as pd
import applications.pycube.models.alpha_constant as alpha_constant
import json

def get_base_data(alpha_db, start_date, end_date):
    """
    Load the strategy base data rows whose tradedate lies in a closed window.

    :param alpha_db: database handle exposing ``executesql(sql)``
    :param start_date: lower bound on ``tradedate`` (inclusive); concatenated
        into the SQL text, so it must be a string
    :param end_date: upper bound on ``tradedate`` (inclusive); same rules
    :return: pandas DataFrame built from the matching ``alpha_stock_data``
        rows, with columns ``alpha_constant.stock_data_column_list``
    """
    # NOTE(review): both dates are spliced into the SQL unescaped; only pass
    # trusted, well-formed date strings.
    date_filter = "tradedate >= '" + start_date + "' and tradedate <= '" + end_date + "'"
    rows = alpha_db.executesql("select * from alpha_stock_data where " + date_filter)
    return pd.DataFrame(data=rows, columns=alpha_constant.stock_data_column_list)

def get_alpha_hs300_weight(alpha_db, tradedate):
    """
    Fetch (tradecode, weight) rows of ``alpha_hs300_weight`` for one date.

    :param alpha_db: database handle exposing ``executesql(sql)``
    :param tradedate: value compared against the ``date`` column; concatenated
        into the SQL text, so it must be a string
    :return: the raw result of ``alpha_db.executesql``
    """
    prefix = "select tradecode,weight from alpha_hs300_weight where date = '"
    return alpha_db.executesql(prefix + tradedate + "'")

def get_hs300_data(alpha_db):
    """
    Read every row of ``alpha_hs300_data``.

    :param alpha_db: database handle exposing ``executesql(sql)``
    :return: the raw result of ``alpha_db.executesql``
    """
    query = "select * from alpha_hs300_data"
    return alpha_db.executesql(query)

def get_stock_industry(alpha_db):
    """
    Read every row of ``alpha_stock_industry``.

    :param alpha_db: database handle exposing ``executesql(sql)``
    :return: the raw result of ``alpha_db.executesql``
    """
    query = "select * from alpha_stock_industry"
    return alpha_db.executesql(query)

def get_st_tradecode(alpha_db):
    """
    Read the (tradedate, tradecode) pairs stored in ``alpha_st_tradecode``.

    :param alpha_db: database handle exposing ``executesql(sql)``
    :return: the raw result of ``alpha_db.executesql``
    """
    query = "SELECT tradedate,tradecode FROM alpha_st_tradecode"
    return alpha_db.executesql(query)

def get_trade_create_data(alpha_db):
    """
    Read the (tradecode, ipo_date) pairs stored in ``alpha_trade_create_data``.

    :param alpha_db: database handle exposing ``executesql(sql)``
    :return: the raw result of ``alpha_db.executesql``
    """
    query = "SELECT tradecode,ipo_date FROM alpha_trade_create_data"
    return alpha_db.executesql(query)

def get_policy(alpha_db, **arg_dict):
    """
    Load one row of ``alpha_policy`` and decode its JSON-encoded columns.

    :param alpha_db: database handle exposing ``executesql(sql)``
    :param arg_dict: lookup key, either ``policy_id`` (matched against the
        ``id`` column; must be convertible with ``int()``) or ``policy_name``
        (matched against the ``policy_name`` column). ``policy_id`` takes
        precedence when both are supplied.
    :return: dict with keys ``id``, ``policy_name``, ``create_date``,
        ``end_date`` plus the JSON-decoded ``select_factor_json``,
        ``predict_json``, ``style_factors``, ``industry_factors``,
        ``alpha_factors``, ``bata_factors``, ``comb_weights`` and
        ``all_factors``; ``None`` when no lookup key is given or no row
        matches (only the first matching row is used).
    :raises ValueError: if ``policy_id`` is not integer-like, or a JSON column
        contains invalid JSON.
    """
    # Columns whose stored text is JSON and must be decoded.
    json_columns = ('select_factor_json', 'predict_json', 'style_factors',
                    'industry_factors', 'alpha_factors', 'bata_factors',
                    'comb_weights', 'all_factors')
    # Full SELECT list; rows are mapped back to these names positionally.
    columns = ('id', 'policy_name', 'create_date') + json_columns + ('end_date',)
    select = 'select ' + ','.join(columns) + ' from alpha_policy where '

    if 'policy_id' in arg_dict:
        # id is compared unquoted, so coerce it to an integer: this accepts
        # both int and numeric-string ids and keeps arbitrary text out of SQL.
        sql = select + 'id=' + str(int(arg_dict['policy_id']))
    elif 'policy_name' in arg_dict:
        # Double embedded single quotes so names like "o'neil" form a valid
        # SQL string literal.
        # NOTE(review): not a complete injection defence (e.g. MySQL backslash
        # escapes); switch to executesql placeholders once the backend's
        # paramstyle is confirmed.
        escaped_name = arg_dict['policy_name'].replace("'", "''")
        sql = select + "policy_name = '" + escaped_name + "'"
    else:
        return None

    rows = alpha_db.executesql(sql)
    if not rows:
        return None

    policy_dict = dict(zip(columns, rows[0]))
    for name in json_columns:
        # NOTE(review): a NULL column yields the *string* '[]' rather than an
        # empty list; kept as-is for compatibility with existing callers.
        raw = policy_dict[name]
        policy_dict[name] = '[]' if raw is None else json.loads(raw)
    return policy_dict

def get_bench_industry_weight(alpha_db, tradedate):
    """
    Sum the HS300 (CSI 300) constituent weights per industry for one date.

    :param alpha_db: database handle exposing ``executesql(sql)``
    :param tradedate: value compared against ``alpha_hs300_weight.date``;
        concatenated into the SQL text, so it must be a string
    :return: pandas DataFrame with columns ``industry`` and
        ``bench_industry_weight`` (one row per industry group)
    """
    # Join each constituent's weight to its industry, then aggregate.
    sql = ("select t1.industry,sum(t2.weight) as bench_industry_weight "
           "from alpha_stock_industry t1 inner join alpha_hs300_weight t2 "
           "on t1.tradecode = t2.tradecode "
           "where t2.date = '" + tradedate + "' group by t1.industry")
    rows = alpha_db.executesql(sql)
    return pd.DataFrame(rows, columns=['industry', 'bench_industry_weight'])

if __name__ == '__main__':
    # This module only defines data-access helpers; running it directly does
    # nothing.
    pass